import os
import shutil
import warnings

import KratosMultiphysics as Kratos
import KratosMultiphysics.GeoMechanicsApplication as KratosGeo
import KratosMultiphysics.KratosUnittest as KratosUnittest
from KratosMultiphysics.GeoMechanicsApplication.set_parameter_field_process import SetParameterFieldProcess

import test_helper


class KratosGeoMechanicsParameterFieldTests(KratosUnittest.TestCase):
    """
    This class contains tests which check if custom parameter fields are correctly added to the model
    """

    # Coefficients of the linear field  value = a * x + b * y  that the tests below expect at every element centre.
    # The field definitions themselves live in the test input folders (not part of this file).
    _FIELD_COEFFICIENT_X = 20000
    _FIELD_COEFFICIENT_Y = 30000

    @staticmethod
    def _get_element_center_coordinates(simulation):
        """
        Return the geometric centre of every element in the model part of the first output process
        """
        elements = simulation._list_of_output_processes[0].model_part.Elements
        return [element.GetGeometry().Center() for element in elements]

    @classmethod
    def _expected_field_value(cls, center_coord):
        """
        Return the field value expected at the given element centre (linear in its x and y coordinates)
        """
        return cls._FIELD_COEFFICIENT_X * center_coord[0] + cls._FIELD_COEFFICIENT_Y * center_coord[1]

    def _run_and_collect(self, file_path, variable):
        """
        Run the Kratos simulation found in ``file_path``.

        Returns a tuple of (element centre coordinates, per-element integration point values of ``variable``).
        """
        simulation = test_helper.run_kratos(file_path)
        center_coords = self._get_element_center_coordinates(simulation)
        results = test_helper.get_on_integration_points(simulation, variable)
        return center_coords, results

    def test_variable_exists_in_python(self):
        """
        Test to check if the variable exists in python

        """
        variable_name = "UMAT_PARAMETERS"
        self.assertTrue(hasattr(KratosGeo, variable_name))

    def test_parameter_field_with_function_umat_parameters(self):
        """
        Test to check if values from a function defined parameter field are correctly added to each individual element

        """
        test_name = os.path.join("test_parameter_field", "parameter_field_input_umat_parameters")
        file_path = test_helper.get_file_path(test_name)

        center_coords, results = self._run_and_collect(file_path, KratosGeo.UMAT_PARAMETERS)

        default_umat_parameters = [1e3, 0.3, 0.0, 30.0, 0.0, 0.0, 1.0, 0.0]
        # Entries 0 and 2 are expected to follow the parameter field; every other entry keeps its default value.
        field_indices = (0, 2)

        # assert (only the first integration point value, res[0], of each element is checked)
        for center_coord, res in zip(center_coords, results):
            expected_res = self._expected_field_value(center_coord)
            umat_parameters = res[0]
            for index, default_value in enumerate(default_umat_parameters):
                expected_value = expected_res if index in field_indices else default_value
                self.assertAlmostEqual(expected_value, umat_parameters[index])

    def test_parameter_field_with_python_umat_parameters(self):
        """
        Test to check if values from a parameter field generated by a user defined python script are correctly added to
        each individual element

        """
        test_name = os.path.join("test_parameter_field", "parameter_field_python_umat_parameters")
        file_path = test_helper.get_file_path(test_name)

        center_coords, results = self._run_and_collect(file_path, KratosGeo.UMAT_PARAMETERS)

        for center_coord, res in zip(center_coords, results):
            self.assertAlmostEqual(self._expected_field_value(center_coord), res[0][0])

    def test_parameter_field_with_json_umat_parameters(self):
        """
        Test to check if values from a parameter field stored in a json file are correctly added to
        each individual element

        """
        test_name = os.path.join("test_parameter_field", "parameter_field_json_umat_parameters")
        file_path = test_helper.get_file_path(test_name)

        center_coords, results = self._run_and_collect(file_path, KratosGeo.UMAT_PARAMETERS)

        for center_coord, res in zip(center_coords, results):
            self.assertAlmostEqual(self._expected_field_value(center_coord), res[0][0])

    def test_parameter_field_with_function(self):
        """
        Test to check if values from a function defined parameter field are correctly added to each individual element

        """
        test_name = os.path.join("test_parameter_field", "parameter_field_input")
        file_path = test_helper.get_file_path(test_name)

        center_coords, results = self._run_and_collect(file_path, Kratos.YOUNG_MODULUS)

        for center_coord, res in zip(center_coords, results):
            self.assertAlmostEqual(self._expected_field_value(center_coord), res[0])

    def test_parameter_field_with_python(self):
        """
        Test to check if values from a parameter field generated by a user defined python script are correctly added to
        each individual element

        """
        test_name = os.path.join("test_parameter_field", "parameter_field_python")
        file_path = test_helper.get_file_path(test_name)

        custom_script_name = "custom_field.py"
        custom_python_file = os.path.join(file_path, custom_script_name)

        # copy user defined python script to installation folder
        new_custom_script_path = os.path.join(os.path.dirname(KratosGeo.__file__), "user_defined_scripts")
        try:
            copied_script = shutil.copy(custom_python_file, new_custom_script_path)
        except OSError as e:
            print(f"Source file path: {custom_python_file}")
            print(f"Destination file path: {new_custom_script_path}")
            print(f"Error occurred while copying the file {e}.")
            raise
        else:
            # Remove the script from the installation folder after the test, also when the simulation or an
            # assertion fails, so a stale script does not leak into later runs.
            self.addCleanup(os.remove, copied_script)

        center_coords, results = self._run_and_collect(file_path, Kratos.YOUNG_MODULUS)

        for center_coord, res in zip(center_coords, results):
            self.assertAlmostEqual(self._expected_field_value(center_coord), res[0])

    def test_parameter_field_with_json(self):
        """
        Test to check if values from a parameter field stored in a json file are correctly added to
        each individual element

        """
        test_name = os.path.join("test_parameter_field", "parameter_field_json")
        file_path = test_helper.get_file_path(test_name)

        center_coords, results = self._run_and_collect(file_path, Kratos.YOUNG_MODULUS)

        for center_coord, res in zip(center_coords, results):
            self.assertAlmostEqual(self._expected_field_value(center_coord), res[0])

    def test_parameter_field_with_invalid_json(self):
        """
        Test to check that a parameter field stored in a json file whose size does not match the number of
        elements in the model part is rejected with a RuntimeError

        """
        test_name = os.path.join("test_parameter_field", "invalid_parameter_field_json")
        file_path = test_helper.get_file_path(test_name)

        # run simulation and assert if correct exception is raised
        with self.assertRaises(RuntimeError) as cm:
            test_helper.run_kratos(file_path)

        self.assertIn(r'Error: The parameter field does not have the same size as '
                      r'the amount of elements within the model part!', str(cm.exception))

    def test_GetVariableBasedOnString(self):
        """
        Test to check if the variable is correctly retrieved from the imported modules
        """

        # dummy variables with YOUNG_MODULUS, which is a variable which is exported to the python module
        settings = Kratos.Parameters("""{
            "model_part_name": "test",
            "variable_name": "YOUNG_MODULUS",
            "dataset": "dummy",
            "func_type": "json_file",
            "function": "dummy",
            "dataset_file_name": "test_file"
        }""")

        # initialize the set parameter field process
        model = Kratos.Model()
        model.CreateModelPart("test")
        process = SetParameterFieldProcess(model, settings)

        variable = process.GetVariableBasedOnString()

        self.assertEqual(variable, Kratos.YOUNG_MODULUS)

    def test_GetVariableBasedOnString_non_existing_variable_in_python(self):
        """
        Test to check if a warning is raised when a variable is not present in the imported modules
        """

        # dummy variables with DENSITY_SOLID_dummy, which is a variable which is not exported to the python module
        settings = Kratos.Parameters("""{
            "model_part_name": "test",
            "variable_name": "DENSITY_SOLID_dummy",
            "dataset": "dummy",
            "func_type": "json_file",
            "function": "dummy",
            "dataset_file_name": "test_file"
        }""")

        # initialize the set parameter field process
        model = Kratos.Model()
        model.CreateModelPart("test")
        process = SetParameterFieldProcess(model, settings)

        # catch the warnings for the test
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")  # Ensure warnings are triggered

            self.assertIsNone(process.GetVariableBasedOnString())

        # Check that exactly one UserWarning with the expected message was raised
        self.assertEqual(len(w), 1)
        self.assertTrue(issubclass(w[-1].category, UserWarning))
        self.assertEqual(str(w[-1].message),
                         "The variable: DENSITY_SOLID_dummy is not present within the imported modules")


if "__main__" == __name__:
    # Executed as a script: hand control to the Kratos unittest runner for every test in this module
    KratosUnittest.main()
